# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Anthropic integration for Claude models."""

from __future__ import annotations

from functools import cached_property
import logging
import os
from typing import AsyncGenerator
from typing import Generator
from typing import Iterable
from typing import Literal
from typing import Optional, Union
from typing import TYPE_CHECKING

from anthropic import AsyncAnthropicVertex
from anthropic import NOT_GIVEN
from anthropic import types as anthropic_types
from google.genai import types
from pydantic import BaseModel
from typing_extensions import override

from .base_llm import BaseLlm
from .llm_response import LlmResponse

if TYPE_CHECKING:
  from .llm_request import LlmRequest

__all__ = ["Claude"]

logger = logging.getLogger(__name__)

# Default maximum number of output tokens to request from the Claude API.
MAX_TOKEN = 1024


class ClaudeRequest(BaseModel):
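  """The fields of a request to be sent to the Claude API."""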
  system_instruction: str
  messages: Iterable[anthropic_types.MessageParam]
  tools: list[anthropic_types.ToolParam]


def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
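  """Converts a google.genai role to a Claude role, defaulting to "user"."""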
  if role in ["model", "assistant"]:
    return "assistant"
  return "user"


def to_google_genai_finish_reason(
    anthropic_stop_reason: Optional[str],
) -> types.FinishReason:
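  """Converts an Anthropic stop reason to a google.genai FinishReason."""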
  if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]:
    return types.FinishReason.STOP
  if anthropic_stop_reason == "max_tokens":
    return types.FinishReason.MAX_TOKENS
  return types.FinishReason.FINISH_REASON_UNSPECIFIED


def part_to_message_block(
    part: types.Part,
) -> Union[
    anthropic_types.TextBlockParam,
    anthropic_types.ImageBlockParam,
    anthropic_types.ToolUseBlockParam,
    anthropic_types.ToolResultBlockParam,
]:
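  """Converts a google.genai Part to the equivalent Anthropic content block."""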
  if part.text:
    return anthropic_types.TextBlockParam(text=part.text, type="text")
  if part.function_call:
    assert part.function_call.name

    return anthropic_types.ToolUseBlockParam(
        id=part.function_call.id or "",
        name=part.function_call.name,
        input=part.function_call.args,
        type="tool_use",
    )
  if part.function_response:
    content = ""
    if (
        "result" in part.function_response.response
        and part.function_response.response["result"]
    ):
      # The result may be a list of dicts, which ToolResultBlockParam
      # content does not accept. Convert it to a string to avoid an
      # anthropic.BadRequestError.
      content = str(part.function_response.response["result"])
    return anthropic_types.ToolResultBlockParam(
        tool_use_id=part.function_response.id or "",
        type="tool_result",
        content=content,
        is_error=False,
    )
  raise NotImplementedError(f"Part type is not supported yet: {part}")


def content_to_message_param(
    content: types.Content,
) -> anthropic_types.MessageParam:
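  """Converts a google.genai Content to an Anthropic MessageParam."""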
  return {
      "role": to_claude_role(content.role),
      "content": [part_to_message_block(part) for part in content.parts or []],
  }


def content_block_to_part(
    content_block: anthropic_types.ContentBlock,
) -> types.Part:
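  """Converts an Anthropic content block back to a google.genai Part."""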
  if isinstance(content_block, anthropic_types.TextBlock):
    return types.Part.from_text(text=content_block.text)
  if isinstance(content_block, anthropic_types.ToolUseBlock):
    assert isinstance(content_block.input, dict)
    part = types.Part.from_function_call(
        name=content_block.name, args=content_block.input
    )
    part.function_call.id = content_block.id
    return part
  raise NotImplementedError(
      f"Content block type is not supported yet: {type(content_block)}"
  )


def message_to_generate_content_response(
    message: anthropic_types.Message,
) -> LlmResponse:
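  """Wraps an Anthropic Message into an LlmResponse."""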

  return LlmResponse(
      content=types.Content(
          role="model",
          parts=[content_block_to_part(cb) for cb in message.content],
      ),
      # TODO: Deal with these later.
      # finish_reason=to_google_genai_finish_reason(message.stop_reason),
      # usage_metadata=types.GenerateContentResponseUsageMetadata(
      #     prompt_token_count=message.usage.input_tokens,
      #     candidates_token_count=message.usage.output_tokens,
      #     total_token_count=(
      #         message.usage.input_tokens + message.usage.output_tokens
      #     ),
      # ),
  )


def function_declaration_to_tool_param(
    function_declaration: types.FunctionDeclaration,
) -> anthropic_types.ToolParam:
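  """Converts a google.genai FunctionDeclaration to an Anthropic ToolParam."""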
  assert function_declaration.name

  properties = {}
  if (
      function_declaration.parameters
      and function_declaration.parameters.properties
  ):
    for key, value in function_declaration.parameters.properties.items():
      value_dict = value.model_dump(exclude_none=True)
      if "type" in value_dict:
        value_dict["type"] = value_dict["type"].lower()
      properties[key] = value_dict

  return anthropic_types.ToolParam(
      name=function_declaration.name,
      description=function_declaration.description or "",
      input_schema={
          "type": "object",
          "properties": properties,
      },
  )


class Claude(BaseLlm):
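  """Integration with Claude models served through Vertex AI.

  Attributes:
    model: The name of the Claude model to use.

  Example (a minimal sketch; assumes `llm_request` is an LlmRequest built by
  the caller):

    claude = Claude(model="claude-3-5-sonnet-v2@20241022")
    async for response in claude.generate_content_async(llm_request):
      print(response.content)
  """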
  model: str = "claude-3-5-sonnet-v2@20241022"

  @staticmethod
  @override
  def supported_models() -> list[str]:
    return [r"claude-3-.*"]

  @override
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
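    """Sends a request to the Claude model and yields the response.

    Streaming is not yet supported; the `stream` flag is currently ignored
    and a single complete response is yielded.

    Args:
      llm_request: The request to send to the Claude model.
      stream: Whether to stream the response (not yet supported).

    Yields:
      The LlmResponse constructed from the Claude message.
    """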
    messages = [
        content_to_message_param(content)
        for content in llm_request.contents or []
    ]
    tools = NOT_GIVEN
    if (
        llm_request.config
        and llm_request.config.tools
        and llm_request.config.tools[0].function_declarations
    ):
      tools = [
          function_declaration_to_tool_param(tool)
          for tool in llm_request.config.tools[0].function_declarations
      ]
    tool_choice = (
        anthropic_types.ToolChoiceAutoParam(
            type="auto",
            # TODO: allow parallel tool use.
            disable_parallel_tool_use=True,
        )
        if llm_request.tools_dict
        else NOT_GIVEN
    )
    message = await self._anthropic_client.messages.create(
        model=llm_request.model,
        system=llm_request.config.system_instruction,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        max_tokens=MAX_TOKEN,
    )
    logger.info(
        "Claude response: %s",
        message.model_dump_json(indent=2, exclude_none=True),
    )
    yield message_to_generate_content_response(message)

  @cached_property
  def _anthropic_client(self) -> AsyncAnthropicVertex:
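    """Builds the Anthropic Vertex client from environment variables."""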
    if (
        "GOOGLE_CLOUD_PROJECT" not in os.environ
        or "GOOGLE_CLOUD_LOCATION" not in os.environ
    ):
      raise ValueError(
          "GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set when"
          " using Anthropic models on Vertex AI."
      )

    return AsyncAnthropicVertex(
        project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
        region=os.environ["GOOGLE_CLOUD_LOCATION"],
    )
